import torch

from .baseline.DRL.evaluator import Evaluator
from .baseline.Renderer.model import FCN
from .baseline.utils.util import *
from argparse import Namespace
from typing import Tuple

from torchbenchmark.tasks import REINFORCEMENT_LEARNING

from ...util.model import BenchmarkModel
from .baseline.DRL.ddpg import DDPG
from .baseline.DRL.multi import fastenv

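# Favor speed over reproducibility: allow nondeterministic kernels and let
# cuDNN benchmark convolution algorithms for the fixed 128 x 128 input size.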
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


class Model(BenchmarkModel):
    task = REINFORCEMENT_LEARNING.OTHER_RL
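    # The batch size doubles as the number of concurrent painting environments
    # (env_batch); 96 matches the default in the original train.py.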
    DEFAULT_TRAIN_BSIZE = 96
    DEFAULT_EVAL_BSIZE = 96
    CANNOT_SET_CUSTOM_OPTIMIZER = True

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        super().__init__(
            test=test, device=device, batch_size=batch_size, extra_args=extra_args
        )
        # Train: these options are taken from the original source code.
        # Source: https://arxiv.org/pdf/1903.04411.pdf
        # Code: https://github.com/megvii-research/ICCV2019-LearningToPaint/blob/master/baseline/train.py
        self.args = Namespace(
            **{
                "validate_episodes": 5,
                "validate_interval": 50,
                "max_step": 40,
                "discount": 0.95**5,
                "episode_train_times": 10,
                "noise_factor": 0.0,
                "tau": 0.001,
                "rmsize": 800,
            }
        )

        # Train: input images come from the CelebA (CelebFaces) dataset, resized to 128 x 128.
        # The original code randomly samples from 200,000 images; here, 2,000 random
        # tensors stand in as input.
        self.width = 128
        self.image_examples = torch.rand(2000, 3, self.width, self.width)

        # LearningToPaint includes actor, critic, and discriminator models.
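        # The FCN Decoder is the differentiable neural renderer that rasterizes
        # stroke parameters into stroke images on the canvas.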
        self.Decoder = FCN()
        self.step = 0
        # Episode counters persist across train() calls so that the policy-update
        # branch in train() fires once every max_step environment steps.
        self.episode = 0
        self.episode_steps = 0
        self.env = fastenv(
            max_episode_length=self.args.max_step,
            env_batch=self.batch_size,
            images=self.image_examples,
            device=self.device,
            Decoder=self.Decoder,
        )
        self.agent = DDPG(
            batch_size=self.batch_size,
            env_batch=self.batch_size,
            max_step=self.args.max_step,
            tau=self.args.tau,
            discount=self.args.discount,
            rmsize=self.args.rmsize,
            device=self.device,
            Decoder=self.Decoder,
        )
        self.evaluate = Evaluator(
            args=self.args, env_batch=self.batch_size, writer=None
        )
        self.observation = self.env.reset()
        self.agent.reset(self.observation, self.args.noise_factor)

        if test == "train":
            self.agent.train()
        elif test == "eval":
            self.agent.eval()

    def get_module(self):
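        # Advance the environment one step so the replay memory is non-empty,
        # then sample a batch and assemble the actor's input the same way the
        # agent's policy does internally.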
        action = self.agent.select_action(
            self.observation, noise_factor=self.args.noise_factor
        )
        self.observation, reward, done, _ = self.env.step(action)
        self.agent.observe(reward, self.observation, done, self.step)
        state, action, reward, next_state, terminal = self.agent.memory.sample_batch(
            self.batch_size, self.device
        )
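        # Actor input: canvas + target image (6 channels, uint8 scaled to [0, 1]),
        # one normalized step-count channel, and two fixed coordinate channels.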
        state = torch.cat(
            (
                state[:, :6].float() / 255,
                state[:, 6:7].float() / self.args.max_step,
                self.agent.coord.expand(state.shape[0], 2, self.width, self.width),
            ),
            1,
        )
        return self.agent.actor, (state,)

    def set_module(self, new_model):
        self.agent.actor = new_model

    def train(self):
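        # Each call performs a single environment step; when an episode completes
        # (every max_step steps), the policy is updated episode_train_times times.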
        self.episode_steps += 1
        if self.observation is None:
            self.observation = self.env.reset()
            self.agent.reset(self.observation, self.args.noise_factor)
        action = self.agent.select_action(
            self.observation, noise_factor=self.args.noise_factor
        )
        self.observation, reward, done, _ = self.env.step(action)
        self.agent.observe(reward, self.observation, done, self.step)
        if self.args.max_step and self.episode_steps >= self.args.max_step:
            # [optional] evaluate
            if (
                self.episode > 0
                and self.args.validate_interval > 0
                and self.episode % self.args.validate_interval == 0
            ):
                reward, dist = self.evaluate(self.env, self.agent.select_action)
            tot_Q = 0.0
            tot_value_loss = 0.0
            lr = (3e-4, 1e-3)  # (critic_lr, actor_lr); the original schedule for early training
            for _ in range(self.args.episode_train_times):
                Q, value_loss = self.agent.update_policy(lr)
                tot_Q += Q.item()
                tot_value_loss += value_loss.item()
            # reset
            self.observation = None
            self.episode_steps = 0
            self.episode += 1
        self.step += 1

    def eval(self) -> Tuple[torch.Tensor, torch.Tensor]:
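        # The Evaluator runs validate_episodes rollouts and returns the mean
        # episodic reward and the mean final distance to the target images.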
        reward, dist = self.evaluate(self.env, self.agent.select_action)
        return (torch.tensor(reward), torch.tensor(dist))
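

# Minimal usage sketch (assumes the standard torchbenchmark harness calling
# conventions; the real entry points are torchbenchmark's run/test scripts):
#
#     model = Model(test="train", device="cuda")
#     model.train()                        # one environment step per call
#     module, example_inputs = model.get_module()
#     module(*example_inputs)              # actor forward pass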
